import numpy as np
from ortools.linear_solver import pywraplp
from Helper import calc_proj_matrix, calc_complement_proj_matrix, calc_mu
from FastL1Basis import getFCT1
from scipy.special import expit

# A = D_w X already computed
def solve_mu_exact(A, solvername="GLOP", norm_bound=100):
    """Solve the exact LP and return the recovered beta* and its mu value.

    The LP has variables zplus[j], zneg[j] >= 0 (one pair per row of ``A``)
    with z = zplus - zneg, and constraints:

    * ``IminusP @ z == 0`` where ``IminusP = calc_complement_proj_matrix(A)``
      (presumably I - P for the projector P onto range(A), which would force
      z into the column space of A -- confirm against Helper);
    * ``0 <= sum(zplus) + sum(zneg) <= norm_bound``.

    The objective minimizes ``sum(zplus) - sum(zneg)``. beta* is then
    recovered from the optimal z* by least squares on ``A @ beta = z*``.

    Args:
        A: (n, d) data matrix; the caller passes D_w X already computed.
        solvername: OR-tools backend name for ``pywraplp.Solver.CreateSolver``
            (e.g. "GLOP", "PDLP").
        norm_bound: Upper bound on ``sum(zplus) + sum(zneg)``. This used to be
            a hard-coded 100 (an arbitrary value of C); the default keeps that.

    Returns:
        ``(betastar, calc_mu(A, betastar))``, or ``None`` if the requested
        solver backend could not be created.
    """
    solver = pywraplp.Solver.CreateSolver(solvername)
    # CreateSolver returns None when the backend is unavailable, so check
    # before calling any method on the solver object.
    if not solver:
        print('ERROR no solver\n')
        return

    # termination_criteria is a field of PDLP's parameter message. Other
    # backends (e.g. the default GLOP) do not define it, so send it only to
    # PDLP and report whether the backend accepted it.
    if solvername.upper() == "PDLP":
        accepted = solver.SetSolverSpecificParametersAsString(
            'termination_criteria {eps_optimal_absolute: 1e-4 eps_optimal_relative: 1e-4}')
        if not accepted:
            print("WARNING: solver-specific parameters rejected by {}".format(solvername))

    IminusP = calc_complement_proj_matrix(A)
    n, d = np.shape(A)
    inf = solver.infinity()

    zplus = {}
    zneg = {}
    for j in range(n):
        zplus[j] = solver.NumVar(0, inf, "zplus[%i]" % j)
        zneg[j] = solver.NumVar(0, inf, "zneg[%i]" % j)

    # (I - P) z = 0, written row by row in terms of zplus - zneg.
    for i in range(n):
        constraint = solver.RowConstraint(0, 0, "")
        for j in range(n):
            # float() so the SWIG wrapper always receives a Python float,
            # whatever dtype IminusP has.
            coef = float(IminusP[i, j])
            constraint.SetCoefficient(zplus[j], coef)
            constraint.SetCoefficient(zneg[j], -coef)

    # Budget row: 0 <= sum(zplus) + sum(zneg) <= norm_bound.
    constraint = solver.RowConstraint(0, norm_bound, "")
    for i in range(n):
        constraint.SetCoefficient(zplus[i], 1.0)
        constraint.SetCoefficient(zneg[i], 1.0)

    # Minimize sum(z) = sum(zplus) - sum(zneg).
    objective = solver.Objective()
    for i in range(n):
        objective.SetCoefficient(zplus[i], 1.0)
        objective.SetCoefficient(zneg[i], -1.0)
    objective.SetMinimization()

    status = solver.Solve()
    print("Exact exit Status %s" % status)
    if status != pywraplp.Solver.OPTIMAL:
        print("WARNING: exact LP not solved to optimality; z* may be unreliable")

    zstar = np.array(
        [zplus[i].solution_value() - zneg[i].solution_value() for i in range(n)])
    betastar = np.linalg.lstsq(A, zstar, rcond=None)[0]

    return betastar, calc_mu(A, betastar)

def getUBneg(U, beta):
    """Return the total magnitude of the non-positive entries of ``U @ beta``.

    Equivalent to ``sum(max(0, -(U @ beta)_i))``. Entries that are exactly
    zero are included but contribute nothing.
    """
    projected = np.dot(U, beta)
    nonpositive = projected <= 0
    return -np.sum(projected[nonpositive])

def solve_mu_approx(A, r1, s, solvername='SAT'):
    """Approximate mu with a MILP on the sketch ``U = getFCT1(A, r1, s)``.

    Model (n = rows of U, d = columns of U):

        min  sum_i bb[i]
        s.t. U @ beta = aa - bb,            aa, bb >= 0
             betanorm[j] = |beta[j]|        (big-M disjunction, per j)
             sum_j betanorm[j] = 1

    At the optimum bb[i] = max(0, -(U @ beta)[i]), so the objective is
    min ||(U beta)^-||_1 over beta with ||beta||_1 = 1.

    Args:
        A: data matrix passed to ``getFCT1``.
        r1, s: sketch parameters forwarded to ``getFCT1``.
        solvername: OR-tools backend; it must support integer variables
            (the default "SAT" is CP-SAT).

    Returns:
        The objective value ``t``, or ``None`` if the backend is unavailable
        or the solve produced no feasible solution.
    """
    solver = pywraplp.Solver.CreateSolver(solvername)
    if not solver:
        print('ERROR no solver\n')
        return

    U = getFCT1(A, r1, s)
    n, d = np.shape(U)
    inf = solver.infinity()

    # From Eq 1/2 below, cc[j] = betanorm[j] + beta[j] and
    # dd[j] = betanorm[j] - beta[j], both >= 0, so |beta[j]| <= betanorm[j].
    # The normalization (Eq 5) gives betanorm[j] <= 1, hence cc[j], dd[j] <= 2.
    # M = 2 is therefore a valid big-M. The old value of 1e6 has the same
    # optimum but a far weaker and numerically fragile relaxation.
    M = 2.0

    aa = {}
    bb = {}
    for i in range(n):
        aa[i] = solver.NumVar(0, inf, "aa[%i]" % i)
        bb[i] = solver.NumVar(0, inf, "bb[%i]" % i)

    betanorm = {}
    cc = {}
    dd = {}
    binaryvar = {}
    beta = {}
    for j in range(d):
        betanorm[j] = solver.NumVar(0, 1, "betanorm[%i]" % j)
        cc[j] = solver.NumVar(0, inf, "cc[%i]" % j)
        dd[j] = solver.NumVar(0, inf, "dd[%i]" % j)
        binaryvar[j] = solver.IntVar(0, 1, "binaryvar[%i]" % j)
        # beta must be allowed to go negative; otherwise the |beta| encoding
        # is pointless. |beta[j]| <= 1 follows from the normalization.
        beta[j] = solver.NumVar(-1, 1, "beta[%i]" % j)

    # |beta| encoding, one set of rows PER COORDINATE j:
    #   Eq 1: betanorm[j] = -beta[j] + cc[j]
    #   Eq 2: betanorm[j] =  beta[j] + dd[j]
    #   Eq 3: 0 <= cc[j] <= M * binaryvar[j]
    #   Eq 4: 0 <= dd[j] <= M * (1 - binaryvar[j])
    # binaryvar[j] = 1 forces dd[j] = 0 (betanorm = beta);
    # binaryvar[j] = 0 forces cc[j] = 0 (betanorm = -beta).
    for j in range(d):
        eq1 = solver.RowConstraint(0, 0, "")
        eq1.SetCoefficient(betanorm[j], 1.0)
        eq1.SetCoefficient(beta[j], 1.0)
        eq1.SetCoefficient(cc[j], -1.0)

        eq2 = solver.RowConstraint(0, 0, "")
        eq2.SetCoefficient(betanorm[j], 1.0)
        eq2.SetCoefficient(beta[j], -1.0)
        eq2.SetCoefficient(dd[j], -1.0)

        # -cc[j] + M * binaryvar[j] >= 0
        eq3 = solver.RowConstraint(0, inf, "")
        eq3.SetCoefficient(cc[j], -1.0)
        eq3.SetCoefficient(binaryvar[j], M)

        # -dd[j] - M * binaryvar[j] >= -M
        eq4 = solver.RowConstraint(-M, inf, "")
        eq4.SetCoefficient(dd[j], -1.0)
        eq4.SetCoefficient(binaryvar[j], -M)

    # U @ beta + bb - aa = 0, i.e. (U @ beta)[i] = aa[i] - bb[i].
    for i in range(n):
        row = solver.RowConstraint(0, 0, "")
        for j in range(d):
            row.SetCoefficient(beta[j], float(U[i, j]))
        row.SetCoefficient(bb[i], 1.0)
        row.SetCoefficient(aa[i], -1.0)

    # Eq 5: sum_j betanorm[j] = 1, i.e. ||beta||_1 = 1.
    normalization = solver.RowConstraint(1, 1, "")
    for j in range(d):
        normalization.SetCoefficient(betanorm[j], 1.0)

    objective = solver.Objective()
    for i in range(n):
        objective.SetCoefficient(bb[i], 1.0)
    objective.SetMinimization()

    status = solver.Solve()
    print("Approx exit Status %s" % status)
    if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
        print("WARNING: approx MILP returned no feasible solution")
        return None

    t = objective.Value()
    print("t={}".format(t))
    return t

def solve_mu_approx_munteanu(A, r1, s, solvername="GLOP"):
    """LP approximation of mu on the sketch ``U = getFCT1(A, r1, s)``.

    Variables: aa, bb >= 0 (per row of U); cc, dd, beta >= 0 (per column).
    Constraints:

    * ``U @ beta + bb - aa == 0`` for every row i;
    * a single aggregated row ``sum_j (beta[j] + dd[j] - cc[j]) == 0``;
    * ``sum_j (cc[j] + dd[j]) >= 1``.

    The objective minimizes ``sum_i bb[i]``.

    NOTE(review): as written, this LP always has optimum 0. beta = 0,
    cc[0] = dd[0] = 0.5 and everything else 0 satisfies every constraint
    with sum(bb) = 0, because nothing forces cc[j] * dd[j] = 0. The
    ``sum(cc + dd) >= 1`` row therefore does not bound ``||beta||_1`` from
    below. ``solve_mu_approx`` enforces this with binary variables. Callers
    that compute ``1 / t`` should expect t == 0.

    Args:
        A: data matrix passed to ``getFCT1``.
        r1, s: sketch parameters forwarded to ``getFCT1``.
        solvername: OR-tools LP backend (e.g. "GLOP", "PDLP").

    Returns:
        The objective value ``t``, or ``None`` if the backend is unavailable.
    """
    solver = pywraplp.Solver.CreateSolver(solvername)
    # CreateSolver returns None when the backend is unavailable; bail out
    # before configuring it or doing the (possibly expensive) sketch.
    if not solver:
        print('ERROR no solver\n')
        return

    # termination_criteria belongs to PDLP's parameter message only. The
    # default GLOP backend does not define it.
    if solvername.upper() == "PDLP":
        accepted = solver.SetSolverSpecificParametersAsString(
            'termination_criteria {eps_optimal_absolute: 1e-4, eps_optimal_relative: 1e-4 }')
        if not accepted:
            print("WARNING: solver-specific parameters rejected by {}".format(solvername))

    U = getFCT1(A, r1, s)
    n, d = np.shape(U)
    inf = solver.infinity()

    aa = {}
    bb = {}
    for i in range(n):
        aa[i] = solver.NumVar(0, inf, "aa[%i]" % i)
        bb[i] = solver.NumVar(0, inf, "bb[%i]" % i)

    cc = {}
    dd = {}
    beta = {}
    for j in range(d):
        cc[j] = solver.NumVar(0, inf, "cc[%i]" % j)
        dd[j] = solver.NumVar(0, inf, "dd[%i]" % j)
        beta[j] = solver.NumVar(0, inf, "beta[%i]" % j)

    # (U @ beta)[i] = aa[i] - bb[i]; minimizing sum(bb) drives bb to the
    # negative part of U @ beta.
    for i in range(n):
        row = solver.RowConstraint(0, 0, "")
        for j in range(d):
            # float() so the SWIG wrapper always receives a Python float.
            row.SetCoefficient(beta[j], float(U[i, j]))
        row.SetCoefficient(bb[i], 1.0)
        row.SetCoefficient(aa[i], -1.0)

    # Single aggregated row: sum_j (beta[j] + dd[j] - cc[j]) = 0.
    split = solver.RowConstraint(0, 0, "")
    for j in range(d):
        split.SetCoefficient(beta[j], 1.0)
        split.SetCoefficient(dd[j], 1.0)
        split.SetCoefficient(cc[j], -1.0)

    # sum_j (cc[j] + dd[j]) >= 1
    scale = solver.RowConstraint(1, inf, "")
    for j in range(d):
        scale.SetCoefficient(dd[j], 1.0)
        scale.SetCoefficient(cc[j], 1.0)

    objective = solver.Objective()
    for i in range(n):
        objective.SetCoefficient(bb[i], 1.0)
    objective.SetMinimization()

    status = solver.Solve()
    print("Approx exit Status %s" % status)
    if status != pywraplp.Solver.OPTIMAL:
        print("WARNING: approx LP not solved to optimality; t may be unreliable")
    t = objective.Value()
    print("t={}".format(t))
    return t



if __name__ == '__main__':
    # Synthetic instance: n points in d dimensions with +/-1 labels.
    n = 1000  # 16384
    d = 1000

    X = np.float32(np.random.rand(n, d))
    trueBeta = np.float32(np.random.randn(d))
    # Noisy score used only to draw labels; thresholded at 0.5 below.
    posterior = np.float32(
        expit(np.dot(X, trueBeta)) + np.float32(np.random.normal(0, 1, n)))
    y = np.where(posterior >= 0.5, 1.0, -1.0)
    # A = D_y X: scale row i of X by y[i]. Broadcasting replaces the
    # deprecated np.matrix construction, so A is a plain ndarray.
    A = np.float32(X * y[:, np.newaxis])

    # exact
    for solvername in ['PDLP']:
        print('Solving exact...with solver = {}'.format(solvername))
        result = solve_mu_exact(A, solvername)
        if result is None:
            # solve_mu_exact already reported that the backend is unavailable.
            continue
        betastar, mu_exact = result

        print("mu = {}".format(mu_exact))
        t = getUBneg(A, betastar)
        print("UBneg = {} lb = {} ub = {}".format(t, 1.0 / t, 10 / t))

    # approx
    # r1 = 256, 512, 1024, 2048. A list cannot be divided by an int, so
    # the original `[...] / 4` raised TypeError; integer division keeps the
    # sketch sizes integral.
    allr1 = [r // 4 for r in (1024, 2048, 4096, 8192)]
    for r1 in allr1:
        s = r1 * 4  # np.power(r1, 3)
        # solve_mu_approx computes U = getFCT1(A, r1, s) itself, so pass A.
        # Passing a precomputed U would apply the sketch twice.
        t = solve_mu_approx(A, r1, s)
        if not t:
            # None (no solver / no solution) or 0 (no usable bound).
            print("r1 = {}: no positive approximate value (t = {})".format(r1, t), flush=True)
            continue
        print("lb = {} ub = {}".format(1.0 / t, 10 / t), flush=True)

